/*
 *    Copyright 2022 The DSMS Authors.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

package com.dsms.common.filter;

import com.dsms.common.constant.ResultCode;
import com.dsms.common.constant.SystemConst;
import com.dsms.common.model.Result;
import com.dsms.common.util.JwtUtil;
import com.dsms.common.util.SystemUtil;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.ExpiredJwtException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.HashOperations;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpStatus;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;


/**
 * JWT Filter
 * before Spring security FilterChains
 */
@Component
public class JwtFilter extends OncePerRequestFilter {
    private static final int REDIS_TTL_MULTIPLE = 4;
    private static final int REDIS_TEMPORARY_TOKEN_TTL = 15;
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    private JwtUtil jwtUtil;
    @Value("${jwt.config.ttl}")
    private long jwtTTl;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        String token = request.getHeader("authorization");
        String refreshToken = request.getHeader("refreshToken");
        String url = request.getRequestURI();
        if (!StringUtils.hasText(token) || !StringUtils.hasText(refreshToken)) {
            filterChain.doFilter(request, response);
            return;
        }
        //parse token
        String userId = null;
        String userName = null;
        HashOperations<String, String, Object> hashOperations = redisTemplate.opsForHash();

        //only those request that require renew token needed to check token
        if (requireRenew(url)) {
            try {
                Claims claims = jwtUtil.parseJWT(token);
                userId = claims.getId();
                userName = claims.getSubject();
            } catch (ExpiredJwtException e) {
                String redisToken = String.valueOf(hashOperations.get(refreshToken, "token"));
                if (Objects.equals(redisToken, token)) {
                    //set a short-lived temporary token to deal with high concurrency issues
                    redisTemplate.opsForValue().setIfAbsent(SystemConst.REDIS_TEMP_TOKEN_PREFIX + redisToken, "tempToken", REDIS_TEMPORARY_TOKEN_TTL, TimeUnit.SECONDS);
                    //renew token if expire
                    String newToken = jwtUtil.createJWT(userId, userName, null);
                    redisTemplate.opsForHash().put(refreshToken, "token", newToken);
                    redisTemplate.expire(refreshToken, jwtTTl * REDIS_TTL_MULTIPLE, TimeUnit.MINUTES);
                    response.setHeader("authorization", newToken);
                    response.setHeader("Access-Control-Expose-Headers", "authorization,refreshToken");
                } else {
                    //compare the temporary token, if it does not exist, return not logged in
                    Object tempToken = redisTemplate.opsForValue().get(SystemConst.REDIS_TEMP_TOKEN_PREFIX + token);
                    if (Objects.isNull(tempToken)) {
                        SystemUtil.responseUtils(request, response, Result.normal(HttpStatus.UNAUTHORIZED, ResultCode.LOGIN_TIMEOUT));
                        return;
                    }
                }
            } catch (Exception e) {
                SystemUtil.responseUtils(request, response, Result.normal(HttpStatus.UNAUTHORIZED, ResultCode.LOGIN_TIMEOUT));
                return;
            }
        }

        Object loginUser = hashOperations.get(refreshToken, "loginUser");
        if (ObjectUtils.isEmpty(loginUser)) {
            SystemUtil.responseUtils(request, response, Result.normal(HttpStatus.UNAUTHORIZED, ResultCode.LOGIN_TIMEOUT));
            return;
        }
        UsernamePasswordAuthenticationToken usernamePasswordAuthenticationToken = new UsernamePasswordAuthenticationToken(loginUser, null, null);
        SecurityContextHolder.getContext().setAuthentication(usernamePasswordAuthenticationToken);
        filterChain.doFilter(request, response);
    }

    /*
     * some request should not renew token,such as scheduled task
     * */
    private boolean requireRenew(String url) {
        //TODO which urls should not to renew token need to be subdivided
        List<String> urls = List.of("/api/monitor/", "/api/alertmessage/get_unconfirmed_message_num");
        return urls.stream().noneMatch(url::startsWith);
    }
}
